package com.niit.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}

object Spark_SQL_UDAF1 {

  def main(args: Array[String]): Unit = {
    // Compute the average age through a custom UDAF registered with Spark SQL.
    val conf = new SparkConf().setAppName("sparkSQL").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    // Load the user data and expose it as a temporary view for SQL queries.
    spark.read.json("input/user.json").createOrReplaceTempView("user")

    // Register the custom aggregate function (UDAF) under the name "ageAvg".
    spark.udf.register("ageAvg", functions.udaf(new MyAvgUDAF))

    // Run the aggregation via SQL and print the result.
    spark.sql("select ageAvg(age) from user").show()

    spark.close()
  }
  /*
  Custom aggregate function that computes the average age.
    1. Extend org.apache.spark.sql.expressions.Aggregator and define its type parameters:
      IN:  input data type  — Long
      BUF: buffer data type — Buff
      OUT: output data type — Long
   */
  // Mutable aggregation buffer: `total` is the running sum of ages,
  // `count` is the number of values folded in so far.
  case class Buff(var total:Long,var count:Long)

  /**
   * Custom typed aggregate function that computes the average age.
   *
   * Type parameters of [[org.apache.spark.sql.expressions.Aggregator]]:
   *   IN  = Long — one input age value
   *   BUF = Buff — running sum and count
   *   OUT = Long — the truncated integer average
   */
  class MyAvgUDAF extends Aggregator[Long,Buff,Long]{

    // Zero value: the initial, empty buffer.
    override def zero: Buff ={
      Buff(0L,0L)
    }
    // Fold one input value into the buffer.
    override def reduce(buff: Buff, in: Long): Buff = {
      buff.total = buff.total + in
      buff.count = buff.count + 1

      buff // return the updated buffer
    }
    // Merge two partial buffers (combining results across partitions).
    override def merge(b1: Buff, b2: Buff): Buff = {
      b1.total = b1.total + b2.total
      b1.count = b1.count + b2.count

      b1 // return the merged buffer
    }
    // Final result: integer average. Guard against division by zero so
    // an empty input group yields 0 instead of an ArithmeticException.
    override def finish(buff: Buff): Long = {
      if (buff.count == 0L) 0L else buff.total / buff.count
    }
    // Encoder for the intermediate buffer (Buff is a case class / Product).
    override def bufferEncoder: Encoder[Buff] = Encoders.product
    // Encoder for the Long output.
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
}
